package resolvers

import (
	"context"

	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/set"
	"github.com/stackrox/rox/pkg/utils"
)

// init registers the GraphQL object types backing VulnerabilityCounterResolver
// and VulnerabilityFixableCounterResolver with the shared schema builder,
// panicking (via utils.Must) if either registration reports an error.
func init() {
	// One entry per bucket; each bucket is a total/fixable pair.
	counterFields := []string{
		"all: VulnerabilityFixableCounterResolver!",
		"low: VulnerabilityFixableCounterResolver!",
		"moderate: VulnerabilityFixableCounterResolver!",
		"important: VulnerabilityFixableCounterResolver!",
		"critical: VulnerabilityFixableCounterResolver!",
	}
	fixableCounterFields := []string{
		"total: Int!",
		"fixable: Int!",
	}

	builder := getBuilder()
	utils.Must(
		builder.AddType("VulnerabilityCounter", counterFields),
		builder.AddType("VulnerabilityFixableCounterResolver", fixableCounterFields),
	)
}

// VulnerabilityCounterResolver groups vulnerability counts into an overall
// bucket plus one bucket per severity level. Each bucket is a
// VulnerabilityFixableCounterResolver holding a total and a fixable count.
type VulnerabilityCounterResolver struct {
	all, low, moderate, important, critical *VulnerabilityFixableCounterResolver
}

// All returns the counter covering vulnerabilities of every severity level.
func (c *VulnerabilityCounterResolver) All(_ context.Context) *VulnerabilityFixableCounterResolver {
	return c.all
}

// Low returns the counter for low-severity vulnerabilities. When the counter is
// filled by incCounterBySev, vulnerabilities of unknown severity land here too.
func (c *VulnerabilityCounterResolver) Low(_ context.Context) *VulnerabilityFixableCounterResolver {
	return c.low
}

// Moderate returns the counter for moderate-severity vulnerabilities.
func (c *VulnerabilityCounterResolver) Moderate(_ context.Context) *VulnerabilityFixableCounterResolver {
	return c.moderate
}

// Important returns the counter for important-severity vulnerabilities.
func (c *VulnerabilityCounterResolver) Important(_ context.Context) *VulnerabilityFixableCounterResolver {
	return c.important
}

// Critical returns the counter for critical-severity vulnerabilities.
func (c *VulnerabilityCounterResolver) Critical(_ context.Context) *VulnerabilityFixableCounterResolver {
	return c.critical
}

// VulnerabilityFixableCounterResolver pairs the number of vulnerabilities in a
// bucket with how many of those are fixable.
type VulnerabilityFixableCounterResolver struct {
	total, fixable int32
}

// Total returns how many vulnerabilities this counter has tallied.
func (c *VulnerabilityFixableCounterResolver) Total(_ context.Context) int32 {
	return c.total
}

// Fixable returns how many of the tallied vulnerabilities are fixable.
func (c *VulnerabilityFixableCounterResolver) Fixable(_ context.Context) int32 {
	return c.fixable
}

// VulnerabilityWithSeverity is the minimal view of a vulnerability needed to
// tally it into a VulnerabilityCounterResolver: an identifier and a severity.
type VulnerabilityWithSeverity interface {
	// GetSeverity returns the vulnerability's severity level.
	GetSeverity() storage.VulnerabilitySeverity
	// GetId returns the vulnerability's identifier.
	GetId() string
}

// Static helpers.
//////////////////

// emptyVulnerabilityCounter returns a VulnerabilityCounterResolver whose
// buckets are all allocated and zeroed. Each bucket is a distinct allocation,
// so callers can increment any of them without nil checks or aliasing.
func emptyVulnerabilityCounter() *VulnerabilityCounterResolver {
	zeroed := func() *VulnerabilityFixableCounterResolver {
		return new(VulnerabilityFixableCounterResolver)
	}
	return &VulnerabilityCounterResolver{
		all:       zeroed(),
		low:       zeroed(),
		moderate:  zeroed(),
		important: zeroed(),
		critical:  zeroed(),
	}
}

// mapVulnsToVulnerabilityCounter tallies embedded vulnerabilities into a fresh
// VulnerabilityCounterResolver.
//
// Suppressed vulnerabilities are skipped. A vulnerability counts as fixable
// when its FixedBy value is non-empty. Every non-suppressed entry in vulns is
// counted; no de-duplication by ID is done here.
func mapVulnsToVulnerabilityCounter(vulns []*storage.EmbeddedVulnerability) *VulnerabilityCounterResolver {
	result := emptyVulnerabilityCounter()
	for i := range vulns {
		v := vulns[i]
		if v.GetSuppressed() {
			continue
		}

		hasFix := len(v.GetFixedBy()) > 0
		incCounterBySev(result, v.GetSeverity(), hasFix)

		result.all.total++
		if hasFix {
			result.all.fixable++
		}
	}
	return result
}

// mapCVEsToVulnerabilityCounter builds a VulnerabilityCounterResolver from a
// list of fixable vulnerabilities and a list of unfixable ones. Each distinct
// ID (as returned by GetId) is counted at most once.
//
// Fixable entries are processed first. An ID that appears in both lists is
// therefore counted once, as fixable. IDs repeated within a single list are
// also counted only once.
func mapCVEsToVulnerabilityCounter(fixable, unFixable []VulnerabilityWithSeverity) *VulnerabilityCounterResolver {
	counter := emptyVulnerabilityCounter()
	// seenVulns holds every ID already tallied. It is shared by both passes, so
	// duplicates are suppressed both across the two lists and within each one.
	seenVulns := set.NewStringSet()

	// tally records vuln in the overall and per-severity buckets, unless its ID
	// has already been counted.
	tally := func(vuln VulnerabilityWithSeverity, isFixable bool) {
		id := vuln.GetId()
		if seenVulns.Contains(id) {
			return
		}
		seenVulns.Add(id)

		counter.all.total++
		if isFixable {
			counter.all.fixable++
		}
		incCounterBySev(counter, vuln.GetSeverity(), isFixable)
	}

	for _, vuln := range fixable {
		tally(vuln, true)
	}
	for _, vuln := range unFixable {
		tally(vuln, false)
	}
	return counter
}

// incCounterBySev adds one vulnerability of severity sev to the matching
// severity bucket of counter, and also bumps that bucket's fixable count when
// fixable is true.
//
// UNKNOWN severity is folded into the low bucket. Any severity value not
// listed below is ignored. The "all" bucket is never touched; callers update
// it themselves.
func incCounterBySev(counter *VulnerabilityCounterResolver, sev storage.VulnerabilitySeverity, fixable bool) {
	var bucket *VulnerabilityFixableCounterResolver
	switch sev {
	case storage.VulnerabilitySeverity_LOW_VULNERABILITY_SEVERITY, storage.VulnerabilitySeverity_UNKNOWN_VULNERABILITY_SEVERITY:
		bucket = counter.low
	case storage.VulnerabilitySeverity_MODERATE_VULNERABILITY_SEVERITY:
		bucket = counter.moderate
	case storage.VulnerabilitySeverity_IMPORTANT_VULNERABILITY_SEVERITY:
		bucket = counter.important
	case storage.VulnerabilitySeverity_CRITICAL_VULNERABILITY_SEVERITY:
		bucket = counter.critical
	default:
		return
	}

	bucket.total++
	if fixable {
		bucket.fixable++
	}
}
